from math import ceil

from compare_bean.origin_data_bean.compare_event import MemoryEvent
from compare_bean.origin_data_bean.trace_event_bean import TraceEventBean
from utils.constant import Constant


class TorchOpNode:
    """A node in a tree of trace events.

    Each node wraps one trace event and links to its parent and child nodes.
    It also records the kernels and memory events attached to it.
    """

    # NOTE(review): the default for ``event`` is the ``TraceEventBean`` class
    # object itself, not an instance -- it reads like a type annotation that
    # was written as a default. It is kept as-is because callers may construct
    # a parentless node with ``TorchOpNode()``; confirm before changing it.
    def __init__(self, event=TraceEventBean, parent_node=None):
        self._event = event
        self._parent_node = parent_node
        self._child_nodes = []
        # Kernels attached directly to this node via set_kernel_list.
        self._kernel_list = []
        # Kernels set on this node and on its descendants; a node without a
        # parent is never incremented (see set_kernel_list).
        self._kernel_num = 0
        self._memory_allocated_list = []

    @property
    def start_time(self):
        """Start time of the wrapped event."""
        return self._event.start_time

    @property
    def end_time(self):
        """End time of the wrapped event."""
        return self._event.end_time

    @property
    def name(self):
        """Name of the wrapped event."""
        return self._event.name

    @property
    def input_shape(self) -> str:
        """String form of the event's "Input Dims" arg (Constant.NA if absent)."""
        return str(self.origin_input_shape)

    @property
    def origin_input_shape(self):
        """Raw "Input Dims" arg of the event, or Constant.NA if absent."""
        return self._event.args.get("Input Dims", Constant.NA)

    @property
    def input_type(self) -> str:
        """String form of the event's "Input type" arg (Constant.NA if absent)."""
        return str(self._event.args.get("Input type", Constant.NA))

    @property
    def call_stack(self) -> str:
        """String form of the event's "Call stack" arg (Constant.NA if absent)."""
        return str(self._event.args.get("Call stack", Constant.NA))

    @property
    def parent(self):
        """Parent node, or None for a node created without one."""
        return self._parent_node

    @property
    def child_nodes(self) -> list:
        """Child nodes in the order they were added (the live internal list)."""
        return self._child_nodes

    @property
    def kernel_list(self) -> list:
        """Kernels attached directly to this node (the live internal list)."""
        return self._kernel_list

    @property
    def kernel_num(self) -> int:
        """Number of kernels set on this node and its descendants."""
        return self._kernel_num

    @property
    def memory_allocated(self) -> list:
        """Memory events recorded on this node (the live internal list)."""
        return self._memory_allocated_list

    def add_child_node(self, child_node):
        """Append ``child_node`` to this node's children."""
        self._child_nodes.append(child_node)

    def set_kernel_list(self, kernel_list: list):
        """Attach kernels to this node and propagate their count upward.

        The kernels are appended to this node's own kernel list. Their count
        is added to ``kernel_num`` of this node and every ancestor that itself
        has a parent; the topmost (parentless) node is deliberately skipped.
        An empty or falsy ``kernel_list`` is a no-op.
        """
        if not kernel_list:
            return
        self._kernel_list.extend(kernel_list)
        kernel_num = len(kernel_list)
        cur_node = self
        # Stop before the parentless node so it never accumulates a count.
        while cur_node._parent_node:
            cur_node._kernel_num += kernel_num
            cur_node = cur_node._parent_node

    def set_memory_allocated(self, memory_allocated: MemoryEvent):
        """Record one memory event on this node (appends; does not replace)."""
        self._memory_allocated_list.append(memory_allocated)

    def is_step_profiler(self) -> bool:
        """Return True if the event name contains "ProfilerStep#"."""
        return "ProfilerStep#" in self.name

    def get_op_info(self) -> list:
        """Return [name, input_shape, input_type, call_stack] for this node."""
        return [self.name, self.input_shape, self.input_type, self.call_stack]
